{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1aeec27a",
   "metadata": {},
   "source": [
    "# Assess predictions on binary text classification blbooksgenre data with a huggingface transformers model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b3a8f67",
   "metadata": {},
   "source": [
    "This notebook demonstrates the use of the `responsibleai` API to assess a text classification huggingface transformers model trained on the blbooksgenre dataset (see https://huggingface.co/datasets/blbooksgenre for more information about the dataset). It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8051a88",
   "metadata": {},
   "source": [
    "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n",
    "    * [Load Model and Data](#Load-Model-and-Data)\n",
    "    * [Create Model and Data Insights](#Create-Model-and-Data-Insights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dae1786",
   "metadata": {},
   "source": [
    "## Launch Responsible AI Toolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "501eb849",
   "metadata": {},
   "source": [
    "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "792b30e9",
   "metadata": {},
   "source": [
    "### Load Model and Data\n",
    "*The following section can be skipped. It loads a dataset and trains a model for illustrative purposes.*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b0aa61d",
   "metadata": {},
   "source": [
    "First we import all necessary dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ef9e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "import zipfile\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import (AutoModelForSequenceClassification, AutoTokenizer,\n",
    "                          pipeline)\n",
    "\n",
    "from raiutils.common.retries import retry_function\n",
    "\n",
    "try:\n",
    "    from urllib import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib.request import urlretrieve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bd63f09",
   "metadata": {},
   "source": [
    "Next we load the blbooksgenre dataset from huggingface datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae1bf09",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TEST_SAMPLES = 20\n",
    "\n",
    "def load_dataset(split):\n",
    "    config_kwargs = {\"name\": \"title_genre_classifiction\"}\n",
    "    dataset = datasets.load_dataset(\"blbooksgenre\", split=split, trust_remote_code=True, **config_kwargs)\n",
    "    return pd.DataFrame({\"text\": dataset[\"title\"], \"label\": dataset[\"label\"]})\n",
    "\n",
    "pd_data = load_dataset(\"train\")\n",
    "\n",
    "pd_data, pd_valid_data = train_test_split(\n",
    "    pd_data, test_size=0.2, random_state=0)\n",
    "\n",
    "START_INDEX = 0\n",
    "train_data = pd_data[NUM_TEST_SAMPLES:].reset_index(drop=True)\n",
    "test_data = pd_valid_data[:NUM_TEST_SAMPLES].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef89ece1",
   "metadata": {},
   "source": [
    "Fetch a pre-trained huggingface model on the blbooksgenre dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f614d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "BLBOOKSGENRE_MODEL_NAME = \"blbooksgenre_model\"\n",
    "NUM_LABELS = 2\n",
    "\n",
    "class FetchModel(object):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def fetch(self):\n",
    "        zipfilename = BLBOOKSGENRE_MODEL_NAME + '.zip'\n",
    "        url = ('https://publictestdatasets.blob.core.windows.net/models/' +\n",
    "               BLBOOKSGENRE_MODEL_NAME + '.zip')\n",
    "        urlretrieve(url, zipfilename)\n",
    "        with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "            unzip.extractall(BLBOOKSGENRE_MODEL_NAME)\n",
    "\n",
    "def retrieve_blbooksgenre_model():\n",
    "    fetcher = FetchModel()\n",
    "    action_name = \"Model download\"\n",
    "    err_msg = \"Failed to download model\"\n",
    "    max_retries = 4\n",
    "    retry_delay = 60\n",
    "    retry_function(fetcher.fetch, action_name, err_msg,\n",
    "                   max_retries=max_retries,\n",
    "                   retry_delay=retry_delay)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BLBOOKSGENRE_MODEL_NAME, num_labels=NUM_LABELS)\n",
    "    return model\n",
    "\n",
    "model = retrieve_blbooksgenre_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fbcba99",
   "metadata": {},
   "source": [
    "Load the model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcd69889",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model and tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "device = -1\n",
    "if device >= 0:\n",
    "    model = model.cuda()\n",
    "\n",
    "# build a pipeline object to do predictions\n",
    "pred = pipeline(\n",
    "    \"text-classification\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    device=device,\n",
    "    return_all_scores=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8f9d39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_wrappers import wrap_model\n",
    "wrapped_model = wrap_model(pred, test_data, 'text_classification')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e26e5068",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"number of errors on test dataset: \" + str(sum(wrapped_model.predict(test_data['text'].tolist()) != test_data['label'].tolist())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d90d728",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = train_data[\"label\"].unique()\n",
    "classes.sort()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f85d7ab1",
   "metadata": {},
   "source": [
    "### Create Model and Data Insights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc3d3dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from responsibleai_text import RAITextInsights, ModelTask\n",
    "from raiwidgets import ResponsibleAIDashboard"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e0de92b",
   "metadata": {},
   "source": [
    "To use Responsible AI Dashboard, initialize a RAITextInsights object upon which different components can be loaded.\n",
    "\n",
    "RAITextInsights accepts the model, the test dataset, the classes and the task type as its arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab932ee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights = RAITextInsights(pred, test_data,\n",
    "                               \"label\",\n",
    "                               task_type=ModelTask.TEXT_CLASSIFICATION,\n",
    "                               classes=classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "460472ac",
   "metadata": {},
   "source": [
    "Add the components of the toolbox for model assessment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e14453",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.explainer.add()\n",
    "rai_insights.error_analysis.add()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1a81d6d",
   "metadata": {},
   "source": [
    "Once all the desired components have been loaded, compute insights on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2416374",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e59be0a8",
   "metadata": {},
   "source": [
    "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1712609a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ResponsibleAIDashboard(rai_insights)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
